{
  "cells": [
    {
      "cell_type": "raw",
      "metadata": {},
      "source": [
        "---\n",
        "title: Matrix multiplication in Mojo\n",
        "description: Learn how to leverage Mojo's various functions to write a high-performance matmul.\n",
        "website:\n",
        "  open-graph:\n",
        "    image: /static/images/mojo-social-card.png\n",
        "  twitter-card:\n",
        "    image: /static/images/mojo-social-card.png\n",
        "---"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "[//]: # REMOVE_FOR_WEBSITE\n",
        "*Copyright 2023 Modular, Inc: Licensed under the Apache License v2.0 with LLVM Exceptions.*"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "[//]: # REMOVE_FOR_WEBSITE\n",
        "# Matrix multiplication in Mojo"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "This notebook describes how to write a matrix multiplication (matmul) algorithm in Mojo. We will start with a pure Python implementation, transition to a naive implementation that is essentially a copy of the Python one, then add types, then continue the optimizations by vectorizing, tiling, and parallelizing the implementation."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "First, let's define matrix multiplication. Given two dense matrices $A$ and $B$ of dimensions $M\\times K$ and $K\\times N$ respectively, we want to compute their matrix product $C = A \\cdot B$ (also known as matmul). The accumulating form $C += A \\cdot B$, which is what the implementations below compute, is defined by"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "$$C_{i,j} += \\sum_{k \\in [0 \\cdots K)} A_{i,k} B_{k,j}$$"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-block alert-success alert--secondary\">\n",
        "\n",
        "Please take a look at our [blog post](https://www.modular.com/blog/ais-compute-fragmentation-what-matrix-multiplication-teaches-us) on matmul and why it is important for ML and DL workloads.\n",
        "\n",
        "</div>"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The format of this notebook is to start with an implementation that is identical to the Python one (effectively just renaming the file extension), then look at how adding types helps performance, and finally extend the implementation to use the vectorization and parallelization capabilities of modern hardware. At each step, we report the GFLOP/s achieved."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "[//]: # REMOVE_FOR_WEBSITE\n",
        "<div class=\"alert alert-block alert-info alert--secondary\">\n",
        "<b>Note:</b> Mojo Playground is designed only for testing the Mojo language.\n",
        "The cloud environment is not always stable and performance varies, so it is not\n",
        "an appropriate environment for performance benchmarking. However, we believe it\n",
        "can still demonstrate the magnitude of performance gains provided by Mojo. For\n",
        "more information about the compute power in the Mojo Playground, see the <a\n",
        "href=\"https://docs.modular.com/mojo/faq.html#mojo-playground\">Mojo FAQ</a>.\n",
        "</div>"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Python Implementation"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's first implement matmul in Python directly from the definition."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "%%python\n",
        "def matmul_python(C, A, B):\n",
        "    for m in range(C.rows):\n",
        "        for k in range(A.cols):\n",
        "            for n in range(C.cols):\n",
        "                C[m, n] += A[m, k] * B[k, n]"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Let's benchmark our implementation using 128 by 128 square matrices and report the achieved GFLOP/s."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Install NumPy if it's not already installed:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {},
      "outputs": [],
      "source": [
        "%%python\n",
        "from importlib.util import find_spec\n",
        "import shutil\n",
        "import subprocess\n",
        "\n",
        "fix = \"\"\"\n",
        "-------------------------------------------------------------------------\n",
        "fix following the steps here:\n",
        "    https://github.com/modularml/mojo/issues/1085#issuecomment-1771403719\n",
        "-------------------------------------------------------------------------\n",
        "\"\"\"\n",
        "\n",
        "def install_if_missing(name: str):\n",
        "    if find_spec(name):\n",
        "        return\n",
        "\n",
        "    print(f\"{name} not found, installing...\")\n",
        "    try:\n",
        "        if shutil.which('python3'): python = \"python3\"\n",
        "        elif shutil.which('python'): python = \"python\"\n",
        "        else: raise RuntimeError(\"python not on path\" + fix)\n",
        "        subprocess.check_call([python, \"-m\", \"pip\", \"install\", name])\n",
        "    except:\n",
        "        raise ImportError(f\"{name} not found\" + fix)\n",
        "\n",
        "install_if_missing(\"numpy\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 3,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "%%python\n",
        "from timeit import timeit\n",
        "import numpy as np\n",
        "\n",
        "class Matrix:\n",
        "    def __init__(self, value, rows, cols):\n",
        "        self.value = value\n",
        "        self.rows = rows\n",
        "        self.cols = cols\n",
        "\n",
        "    def __getitem__(self, idxs):\n",
        "        return self.value[idxs[0]][idxs[1]]\n",
        "\n",
        "    def __setitem__(self, idxs, value):\n",
        "        self.value[idxs[0]][idxs[1]] = value\n",
        "\n",
        "def benchmark_matmul_python(M, N, K):\n",
        "    A = Matrix(list(np.random.rand(M, K)), M, K)\n",
        "    B = Matrix(list(np.random.rand(K, N)), K, N)\n",
        "    C = Matrix(list(np.zeros((M, N))), M, N)\n",
        "    secs = timeit(lambda: matmul_python(C, A, B), number=2)/2\n",
        "    gflops = ((2*M*N*K)/secs) / 1e9\n",
        "    print(gflops, \"GFLOP/s\")\n",
        "    return gflops"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 4,
      "metadata": {
        "tags": []
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.0022564560694965834 GFLOP/s\n"
          ]
        }
      ],
      "source": [
        "python_gflops = benchmark_matmul_python(128, 128, 128).to_float64()"
      ]
    },
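    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The GFLOP/s figure above comes from a simple operation count. As a sketch of the arithmetic (the symbol $t$ for the measured time per call is our own notation, not a name used in the code): each of the $M \\times N$ entries of $C$ needs $K$ multiplications and $K$ additions, so\n",
        "\n",
        "$$\\text{FLOPs} = 2MNK, \\qquad \\text{GFLOP/s} = \\frac{2MNK}{t \\cdot 10^{9}}$$\n",
        "\n",
        "For $M = N = K = 128$ this is $2 \\cdot 128^3 = 4{,}194{,}304$ floating-point operations per call, so a rate of about $0.0023$ GFLOP/s corresponds to roughly $1.9$ seconds per call."
      ]
    },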
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Importing the Python implementation to Mojo"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Using Mojo is as simple as using Python. First, let's import the modules from the Mojo standard library that we are going to use:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 5,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "#|code-fold: true\n",
        "#|code-summary: \"Import utilities and define `Matrix` (click to show/hide)\"\n",
        "\n",
        "import benchmark\n",
        "from memory import memset_zero\n",
        "from random import rand, random_float64"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Then, we can copy and paste our Python code. Mojo is a superset of Python, so the same Python code will run as Mojo code."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 6,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# This is exactly the same Python implementation, but it is in fact Mojo code!\n",
        "def matmul_untyped(C, A, B):\n",
        "    for m in range(C.rows):\n",
        "        for k in range(A.cols):\n",
        "            for n in range(C.cols):\n",
        "                C[m, n] += A[m, k] * B[k, n]"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can then benchmark the implementation. As before, we use 128 by 128 matrices."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "fn matrix_getitem(self: object, i: object) raises -> object:\n",
        "    return self.value[i]\n",
        "\n",
        "\n",
        "fn matrix_setitem(self: object, i: object, value: object) raises -> object:\n",
        "    self.value[i] = value\n",
        "    return None\n",
        "\n",
        "\n",
        "fn matrix_append(self: object, value: object) raises -> object:\n",
        "    self.value.append(value)\n",
        "    return None\n",
        "\n",
        "\n",
        "fn matrix_init(rows: Int, cols: Int) raises -> object:\n",
        "    var value = object([])\n",
        "    return object(\n",
        "        Attr(\"value\", value), Attr(\"__getitem__\", matrix_getitem), Attr(\"__setitem__\", matrix_setitem),\n",
        "        Attr(\"rows\", rows), Attr(\"cols\", cols), Attr(\"append\", matrix_append),\n",
        "    )\n",
        "\n",
        "def benchmark_matmul_untyped(M: Int, N: Int, K: Int, python_gflops: Float64):\n",
        "    C = matrix_init(M, N)\n",
        "    A = matrix_init(M, K)\n",
        "    B = matrix_init(K, N)\n",
        "    for i in range(M):\n",
        "        c_row = object([])\n",
        "        b_row = object([])\n",
        "        a_row = object([])\n",
        "        for j in range(N):\n",
        "            c_row.append(0.0)\n",
        "            b_row.append(random_float64(-5, 5))\n",
        "            a_row.append(random_float64(-5, 5))\n",
        "        C.append(c_row)\n",
        "        B.append(b_row)\n",
        "        A.append(a_row)\n",
        "\n",
        "    @parameter\n",
        "    fn test_fn():\n",
        "        try:\n",
        "            _ = matmul_untyped(C, A, B)\n",
        "        except:\n",
        "            pass\n",
        "\n",
        "    var secs = benchmark.run[test_fn](max_runtime_secs=0.5).mean()\n",
        "    _ = (A, B, C)\n",
        "    var gflops = ((2*M*N*K)/secs) / 1e9\n",
        "    var speedup : Float64 = gflops / python_gflops\n",
        "    print(gflops, \"GFLOP/s, a\", speedup, \"x speedup over Python\")"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "tags": []
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "0.019322323885887414 GFLOP/s, a 8.563128769530282 x speedup over Python\n"
          ]
        }
      ],
      "source": [
        "benchmark_matmul_untyped(128, 128, 128, python_gflops)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Note the roughly 8.5x speedup we got without changing the code."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Adding types to the Python implementation"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The above program, while achieving better performance than Python, is still not the best we can get from Mojo. If we tell Mojo the types of the inputs, it can optimize much of the code away and reduce dispatching costs (unlike Python, which only uses types for type checking, Mojo exploits type info for performance optimizations as well)."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To do that, let's first define a `Matrix` struct. The `Matrix` struct contains a data pointer along with size fields. While the `Matrix` struct can be parametrized on any data type, here we set the data type to be Float32 for conciseness."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 9,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "alias type = DType.float32\n",
        "\n",
        "struct Matrix[rows: Int, cols: Int]:\n",
        "    var data: DTypePointer[type]\n",
        "\n",
        "    # Initialize zeroing all values\n",
        "    fn __init__(inout self):\n",
        "        self.data = DTypePointer[type].alloc(rows * cols)\n",
        "        memset_zero(self.data, rows * cols)\n",
        "\n",
        "    # Initialize taking a pointer, don't set any elements\n",
        "    fn __init__(inout self, data: DTypePointer[type]):\n",
        "        self.data = data\n",
        "\n",
        "    # Initialize with random values\n",
        "    @staticmethod\n",
        "    fn rand() -> Self:\n",
        "        var data = DTypePointer[type].alloc(rows * cols)\n",
        "        rand(data, rows * cols)\n",
        "        return Self(data)\n",
        "\n",
        "    fn __getitem__(self, y: Int, x: Int) -> Scalar[type]:\n",
        "        return self.load[1](y, x)\n",
        "\n",
        "    fn __setitem__(self, y: Int, x: Int, val: Scalar[type]):\n",
        "        self.store[1](y, x, val)\n",
        "\n",
        "    fn load[nelts: Int](self, y: Int, x: Int) -> SIMD[type, nelts]:\n",
        "        return self.data.load[width=nelts](y * self.cols + x)\n",
        "\n",
        "    fn store[nelts: Int](self, y: Int, x: Int, val: SIMD[type, nelts]):\n",
        "        return self.data.store[width=nelts](y * self.cols + x, val)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "<div class=\"alert alert-block alert-success alert--secondary\">\n",
        "\n",
        "Note that we implement `__getitem__` and `__setitem__` in terms of `load` and `store`. For the naive implementation of matmul it does not make a difference, but we will utilize this later in a more optimized vectorized version of matmul.\n",
        "\n",
        "</div>"
      ]
    },
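    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The index arithmetic in `load` and `store` treats the flat buffer as a row-major matrix. As a small worked example (the name $\\text{offset}$ is ours, used only for illustration), the element at row $y$ and column $x$ lives at\n",
        "\n",
        "$$\\text{offset}(y, x) = y \\cdot \\text{cols} + x$$\n",
        "\n",
        "For a matrix with $1024$ columns, element $(2, 3)$ is at offset $2 \\cdot 1024 + 3 = 2051$. Elements of the same row are contiguous in memory, so a `load[nelts]` at $(y, x)$ reads the `nelts` consecutive entries of row $y$ starting at column $x$."
      ]
    },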
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With the above `Matrix` type we can effectively copy and paste the Python implementation and just add type annotations:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 10,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Note that C, A, and B have types.\n",
        "fn matmul_naive(C: Matrix, A: Matrix, B: Matrix):\n",
        "    for m in range(C.rows):\n",
        "        for k in range(A.cols):\n",
        "            for n in range(C.cols):\n",
        "                C[m, n] += A[m, k] * B[k, n]"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We are going to benchmark the implementations as we improve them, so let's write a helper function that will do that for us:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 11,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "alias M = 1024\n",
        "alias N = 1024\n",
        "alias K = 1024\n",
        "\n",
        "@always_inline\n",
        "fn bench[\n",
        "    func: fn (Matrix, Matrix, Matrix) -> None](base_gflops: Float64):\n",
        "    var C = Matrix[M, N]()\n",
        "    var A = Matrix[M, K].rand()\n",
        "    var B = Matrix[K, N].rand()\n",
        "\n",
        "    @always_inline\n",
        "    @parameter\n",
        "    fn test_fn():\n",
        "        _ = func(C, A, B)\n",
        "\n",
        "    var secs = benchmark.run[test_fn](max_runtime_secs=1).mean()\n",
        "\n",
        "    A.data.free()\n",
        "    B.data.free()\n",
        "    C.data.free()\n",
        "\n",
        "    var gflops = ((2 * M * N * K) / secs) / 1e9\n",
        "    var speedup: Float64 = gflops / base_gflops\n",
        "\n",
        "    print(gflops, \"GFLOP/s, a\", speedup, \"x speedup over Python\")"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Benchmarking shows significant speedups. Because Mojo is much faster than Python, we increase the size of the matrices to 1024 by 1024 (the `M`, `N`, and `K` aliases in the helper above)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 12,
      "metadata": {
        "tags": []
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "5.3335685138294373 GFLOP/s, a 2363.6926000599515 x speedup over Python\n"
          ]
        }
      ],
      "source": [
        "bench[matmul_naive](python_gflops)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Adding type annotations gives a huge improvement over the original untyped version: from about 0.02 GFLOP/s to about 5.3 GFLOP/s in the runs shown here."
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Vectorizing the innermost loop"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can do better than the above implementation by utilizing vector instructions. Rather than assuming a vector width, we query the SIMD width of the specified dtype using `simdwidthof`. This makes our code portable as we transition to other hardware. Leveraging SIMD instructions is as easy as:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 13,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# simdwidthof = number of float32 elements that fit into a single SIMD register\n",
        "# using a 2x multiplier allows some SIMD operations to run in the same cycle\n",
        "alias nelts = simdwidthof[DType.float32]() * 2\n",
        "\n",
        "fn matmul_vectorized_0(C: Matrix, A: Matrix, B: Matrix):\n",
        "    for m in range(C.rows):\n",
        "        for k in range(A.cols):\n",
        "            # Only iterate over full vectors so the last load/store stays in bounds.\n",
        "            for nv in range(0, nelts * (C.cols // nelts), nelts):\n",
        "                C.store(m, nv, C.load[nelts](m, nv) + A[m, k] * B.load[nelts](k, nv))\n",
        "\n",
        "            # Handle remaining elements with scalars.\n",
        "            for n in range(nelts * (C.cols // nelts), C.cols):\n",
        "                C[m, n] += A[m, k] * B[k, n]"
      ]
    },
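    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the split between the vector loop and the scalar remainder loop concrete, here is a small worked example. The values are assumptions chosen for illustration only: `nelts` depends on the hardware, and we pick $\\text{nelts} = 16$ and a row length of $N = 1000$ columns. The scalar loop starts at\n",
        "\n",
        "$$\\text{nelts} \\cdot \\left\\lfloor \\frac{N}{\\text{nelts}} \\right\\rfloor = 16 \\cdot \\left\\lfloor \\frac{1000}{16} \\right\\rfloor = 16 \\cdot 62 = 992$$\n",
        "\n",
        "so columns $[0, 992)$ are covered by $62$ full vectors and the remaining $8$ columns $[992, 1000)$ are handled one element at a time. When $N$ is a multiple of `nelts`, as with a power-of-two `nelts` and the $1024$-column matrices used in the benchmark, the remainder loop does nothing."
      ]
    },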
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can benchmark the above implementation. Note that many compilers can detect naive loops and perform optimizations on them. Mojo, however, allows you to be explicit and precisely control what optimizations are applied."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 14,
      "metadata": {
        "tags": []
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "27.943656514151261 GFLOP/s, a 12383.869064371152 x speedup over Python\n"
          ]
        }
      ],
      "source": [
        "bench[matmul_vectorized_0](python_gflops)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Vectorization is a common optimization, and Mojo provides a higher-order function that performs vectorization for you. The `vectorize` function takes a vector width and a function which is parametric on the vector width and is going to be evaluated in a vectorized manner."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 15,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Simplify the code by using the builtin vectorize function\n",
        "from algorithm import vectorize\n",
        "\n",
        "fn matmul_vectorized_1(C: Matrix, A: Matrix, B: Matrix):\n",
        "    for m in range(C.rows):\n",
        "        for k in range(A.cols):\n",
        "            @parameter\n",
        "            fn dot[nelts: Int](n: Int):\n",
        "                C.store(m, n, C.load[nelts](m, n) + A[m, k] * B.load[nelts](k, n))\n",
        "            vectorize[dot, nelts, size = C.cols]()"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "There is only a slight difference in terms of performance between the two implementations:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 16,
      "metadata": {
        "tags": []
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "29.94507372793699 GFLOP/s, a 13270.842775422503 x speedup over Python\n"
          ]
        }
      ],
      "source": [
        "bench[matmul_vectorized_1](python_gflops)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Parallelizing Matmul"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "With Mojo we can easily run code in parallel with the `parallelize` function.\n",
        "\n",
        "Let's modify our matmul implementation and make it multi-threaded (for simplicity, we only `parallelize` on the M dimension).\n",
        "\n",
        "In the call to `parallelize` below we're overpartitioning: we ask for as many workers as there are rows (`C.rows`), which is typically far more than the number of cores. Splitting the work into many small tasks distributes it more evenly among processors, so they all have something to work on even if some tasks finish before others or some processors are stragglers. This helps on recent Intel and Apple chips, which have separate performance and efficiency cores."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Parallelize the code by using the builtin parallelize function\n",
        "from algorithm import parallelize\n",
        "\n",
        "fn matmul_parallelized(C: Matrix, A: Matrix, B: Matrix):\n",
        "    @parameter\n",
        "    fn calc_row(m: Int):\n",
        "        for k in range(A.cols):\n",
        "            @parameter\n",
        "            fn dot[nelts : Int](n : Int):\n",
        "                C.store[nelts](m,n, C.load[nelts](m,n) + A[m,k] * B.load[nelts](k,n))\n",
        "            vectorize[dot, nelts, size = C.cols]()\n",
        "    parallelize[calc_row](C.rows, C.rows)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "We can benchmark the parallel matmul implementation."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {
        "tags": []
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "860.6608444221323 GFLOP/s, a 381421.49366734456 x speedup over Python\n"
          ]
        }
      ],
      "source": [
        "bench[matmul_parallelized](python_gflops)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Tiling Matmul"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Tiling is an optimization performed for matmul to increase cache locality. The idea is to keep sub-matrices resident in the cache and increase the reuse. The tile function itself can be written in Mojo as:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "from algorithm import Static2DTileUnitFunc as Tile2DFunc\n",
        "\n",
        "# Perform 2D tiling on the iteration space defined by end_x and end_y.\n",
        "fn tile[tiled_fn: Tile2DFunc, tile_x: Int, tile_y: Int](end_x: Int, end_y: Int):\n",
        "    # Note: this assumes that ends are multiples of the tiles.\n",
        "    for y in range(0, end_y, tile_y):\n",
        "        for x in range(0, end_x, tile_x):\n",
        "            tiled_fn[tile_x, tile_y](x, y)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The above performs two-dimensional tiling over a 2D iteration space defined by $[0, end_x) \\times [0, end_y)$. Once we have defined it, we can use it within our matmul kernel. For simplicity we choose `4` as the tile height, and since we also want to vectorize, we use `4 * nelts` as the tile width (since we vectorize on the columns)."
      ]
    },
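    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a rough sketch of what these tile sizes mean for the benchmark, assume for illustration that $\\text{nelts} = 16$ (the real value depends on the hardware). A tile is then $4 \\cdot 16 = 64$ columns wide and $4$ steps tall, and for $N = K = 1024$ the iteration space for one row of $C$ splits into\n",
        "\n",
        "$$\\frac{1024}{64} \\cdot \\frac{1024}{4} = 16 \\cdot 256 = 4096$$\n",
        "\n",
        "tiles. Within one tile, the same $64$-element slice of a row of $C$ is updated for $4$ consecutive values of $k$, which is the reuse that tiling is meant to expose."
      ]
    },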
    {
      "cell_type": "code",
      "execution_count": 20,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Use the above tile function to perform tiled matmul.\n",
        "fn matmul_tiled_parallelized(C: Matrix, A: Matrix, B: Matrix):\n",
        "    @parameter\n",
        "    fn calc_row(m: Int):\n",
        "        @parameter\n",
        "        fn calc_tile[tile_x: Int, tile_y: Int](x: Int, y: Int):\n",
        "            for k in range(y, y + tile_y):\n",
        "                @parameter\n",
        "                fn dot[nelts: Int](n: Int):\n",
        "                    C.store(m, n + x, C.load[nelts](m, n + x) + A[m, k] * B.load[nelts](k, n + x))\n",
        "                vectorize[dot, nelts, size = tile_x]()\n",
        "\n",
        "        # We hardcode the tile factor to be 4.\n",
        "        alias tile_size = 4\n",
        "        # x spans the columns of C (n); y spans the reduction dimension (k).\n",
        "        tile[calc_tile, nelts * tile_size, tile_size](C.cols, A.cols)\n",
        "\n",
        "    parallelize[calc_row](C.rows, C.rows)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Again, we can benchmark the tiled parallel matmul implementation:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 21,
      "metadata": {
        "tags": []
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "898.79112716147336 GFLOP/s, a 398319.79860436375 x speedup over Python\n"
          ]
        }
      ],
      "source": [
        "bench[matmul_tiled_parallelized](python_gflops)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "One source of overhead in the above implementation is that we are not unrolling the loop that `vectorize` introduces around the `dot` function. We can do that with the `unroll_factor` parameter of `vectorize`:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 22,
      "metadata": {
        "tags": []
      },
      "outputs": [],
      "source": [
        "# Unroll the vectorized loop by a constant factor.\n",
        "fn matmul_tiled_unrolled_parallelized(C: Matrix, A: Matrix, B: Matrix):\n",
        "    @parameter\n",
        "    fn calc_row(m: Int):\n",
        "        @parameter\n",
        "        fn calc_tile[tile_x: Int, tile_y: Int](x: Int, y: Int):\n",
        "            for k in range(y, y + tile_y):\n",
        "                @parameter\n",
        "                fn dot[nelts: Int](n: Int):\n",
        "                    C.store(m, n + x, C.load[nelts](m, n + x) + A[m, k] * B.load[nelts](k, n + x))\n",
        "\n",
        "                # Vectorize by nelts and unroll by tile_x/nelts\n",
        "                # Here unroll factor is 4\n",
        "                alias unroll_factor = tile_x // nelts\n",
        "                vectorize[dot, nelts, size=tile_x, unroll_factor=unroll_factor]()\n",
        "\n",
        "        alias tile_size = 4\n",
        "        # x spans the columns of C (n); y spans the reduction dimension (k).\n",
        "        tile[calc_tile, nelts * tile_size, tile_size](C.cols, A.cols)\n",
        "\n",
        "    parallelize[calc_row](C.rows, C.rows)"
      ]
    },
    {
      "attachments": {},
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "Again, we can benchmark the new tiled parallel matmul implementation with an unrolled and vectorized inner loop:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 23,
      "metadata": {},
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "1026.9730406876936 GFLOP/s, a 455126.53872176283 x speedup over Python\n"
          ]
        }
      ],
      "source": [
        "#| CHECK: GFLOP\n",
        "bench[matmul_tiled_unrolled_parallelized](python_gflops)"
      ]
    }
  ],
  "metadata": {
    "kernelspec": {
      "display_name": "Mojo",
      "language": "mojo",
      "name": "mojo-jupyter-kernel"
    },
    "language_info": {
      "codemirror_mode": {
        "name": "mojo"
      },
      "file_extension": ".mojo",
      "mimetype": "text/x-mojo",
      "name": "mojo"
    }
  },
  "nbformat": 4,
  "nbformat_minor": 4
}
